import java.util.HashMap;
import java.util.Map;

public class Solution525 {
    /**
     * Computes the running balance of {@code nums}, position by position.
     *
     * <p>Each element equal to {@code 1} adds one to the balance. Every other value subtracts one.
     * Entry {@code k} of the result is the balance after processing {@code nums[0..k]}.
     *
     * @param nums the input values (must not be {@code null})
     * @return a new array of the same length holding the prefix balances
     */
    public int[] getDiffArr(int[] nums){
        int[] balances = new int[nums.length];
        int running = 0;
        int idx = 0;
        for (int value : nums) {
            running += (value == 1) ? 1 : -1;
            balances[idx++] = running;
        }
        return balances;
    }


    /**
     * Returns the length of the longest contiguous subarray whose balance is zero.
     *
     * <p>A zero balance means that subarray has as many {@code 1}s as other values. Two prefixes
     * with the same balance enclose such a subarray. So the earliest index at which each balance
     * was seen is remembered, and distances back to it are measured.
     *
     * @param nums the input values (must not be {@code null})
     * @return the maximum length found, or {@code 0} if no such subarray exists
     */
    public int findMaxLength(int[] nums) {
        int[] balances = getDiffArr(nums);
        Map<Integer, Integer> firstSeen = new HashMap<>();
        // A balance of 0 "occurs" just before the array starts, so prefixes that are
        // already balanced measure their full length from index -1.
        firstSeen.put(0, -1);
        int best = 0;
        for (int end = 0; end < balances.length; end++) {
            int balance = balances[end];
            Integer start = firstSeen.get(balance);
            if (start == null) {
                // Keep only the first occurrence; later ones would give shorter spans.
                firstSeen.put(balance, end);
            } else {
                best = Math.max(best, end - start);
            }
        }
        return best;
    }

    /** Small manual demo: prints the answer for {@code [0, 1, 1]}. */
    public static void main(String[] args) {
        Solution525 solver = new Solution525();
        int[] sample = new int[]{0,1,1};
        System.out.println(solver.findMaxLength(sample));
    }
}
